﻿using DotNetCommon;
using DotNetCommon.Logger;
using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;

namespace DBUtil.Provider.MySql
{
    /// <summary>
    /// <see cref="ILocker"/> implementation for MySQL. Callers are serialized in two stages:
    /// first by a per-name in-process <see cref="SemaphoreSlim"/>, then by a MySQL named lock
    /// (<c>GET_LOCK</c> / <c>RELEASE_LOCK</c>) taken on a separately created <see cref="DBAccess"/>.
    /// </summary>
    public class MySqlLocker : ILocker
    {
        private static readonly ILogger<MySqlLocker> logger = LoggerFactory.CreateLogger<MySqlLocker>();

        // One semaphore per lock name, shared by every MySqlLocker instance in this process.
        // NOTE(review): entries are never removed, so the map grows with the number of distinct lock names.
        private static readonly ConcurrentDictionary<string, SemaphoreSlim> _lockSemaphores = new ConcurrentDictionary<string, SemaphoreSlim>();

        // Smallest timeout (in seconds) handed to GET_LOCK once the in-process wait has used up the budget.
        private const int MinDbLockTimeoutSeconds = 1;

        /// <summary>
        /// Computes the timeout, in seconds, that is left for the database lock after the in-process wait.
        /// </summary>
        /// <param name="totalTimeoutSecond">Overall timeout requested by the caller, in seconds.</param>
        /// <param name="elapsed">Time already spent waiting for the in-process semaphore.</param>
        /// <returns>The remaining whole seconds, never less than <see cref="MinDbLockTimeoutSeconds"/>.</returns>
        private static int GetRemainingDbTimeoutSeconds(int totalTimeoutSecond, TimeSpan elapsed)
        {
            var remaining = totalTimeoutSecond - (int)elapsed.TotalSeconds;
            // GET_LOCK takes its timeout in seconds (and treats a negative value as "wait forever"),
            // so the floor must be a small number of seconds rather than a millisecond-style value.
            return Math.Max(remaining, MinDbLockTimeoutSeconds);
        }

        /// <summary>
        /// Runs <paramref name="func"/> while holding both the in-process semaphore and the MySQL named lock
        /// for <paramref name="lock_str"/>.
        /// </summary>
        /// <typeparam name="T">Result type of <paramref name="func"/>.</typeparam>
        /// <param name="db">Supplies the database type and connection used to create the locking <see cref="DBAccess"/>.</param>
        /// <param name="lock_str">Logical lock name; prefixed with <c>dbutil:runinlock:</c> for the database lock.</param>
        /// <param name="func">The work to execute under the lock.</param>
        /// <param name="getLockTimeoutSecond">Overall time budget, in seconds, for acquiring the lock.</param>
        /// <returns>The value produced by <paramref name="func"/>.</returns>
        /// <exception cref="Exception">
        /// The in-process semaphore was not acquired in time, or <c>GET_LOCK</c> did not return 1.
        /// </exception>
        private async Task<T> RunInLockInternalAsync<T>(DBAccess db, string lock_str, Func<Task<T>> func, int getLockTimeoutSecond)
        {
            // Stage 1: block concurrent callers inside this process first.
            // UTC is used so the elapsed-time computation is immune to DST / local clock shifts.
            var startedAt = DateTime.UtcNow;
            var asyncLock = _lockSemaphores.GetOrAdd(lock_str, _ => new SemaphoreSlim(1, 1));
            // TimeSpan avoids the int overflow that "seconds * 1000" hits for very large timeouts.
            var entered = await asyncLock.WaitAsync(TimeSpan.FromSeconds(getLockTimeoutSecond));
            if (!entered) throw new Exception($"获取锁[{lock_str}]超时,单机内超时.");
            try
            {
                // Create a fresh DBAccess so the lock lives in a session isolated from the caller's.
                var tmpDb = DBFactory.CreateDB(db.DBType, db.DBConn);

                // Stage 2: take the database-level lock with whatever time budget is left.
                var dbTimeoutSeconds = GetRemainingDbTimeoutSeconds(getLockTimeoutSecond, DateTime.UtcNow - startedAt);
                var dbLockName = tmpDb.EscapeString("dbutil:runinlock:" + lock_str);
                var acquireSql = $"select GET_LOCK('{dbLockName}',{dbTimeoutSeconds})";

                return await tmpDb.RunInSessionAsync(async () =>
                {
                    var res = await tmpDb.SelectScalarAsync<int>(acquireSql);
                    if (res != 1)
                    {
                        throw new Exception($"数据库获取锁失败: GET_LOCK 返回:{res}");
                    }

                    // Lock acquired: run the caller's work, then always try to release the database lock.
                    try
                    {
                        return await func();
                    }
                    finally
                    {
                        try
                        {
                            await tmpDb.ExecuteSqlAsync($"select RELEASE_LOCK('{dbLockName}')");
                        }
                        catch (Exception ex)
                        {
                            // Best effort: a failed release is logged and must not mask the outcome of func.
                            logger.LogError($"尝试释放基于数据库的分布式锁失败(lock_str={dbLockName}),异常信息:{ex?.Message}");
                        }
                    }
                });
            }
            finally
            {
                // Deliberately best effort: releasing the in-process semaphore must never throw out of here.
                try { asyncLock.Release(); } catch { }
            }
        }

        /// <summary>
        /// Synchronously runs <paramref name="action"/> under the lock named <paramref name="lock_str"/>.
        /// Blocks the calling thread via <see cref="Task.Wait()"/>, so failures surface wrapped in an
        /// <see cref="AggregateException"/>.
        /// </summary>
        /// <param name="db">Supplies the database type and connection used for the database lock.</param>
        /// <param name="lock_str">Logical lock name; must not be null.</param>
        /// <param name="action">The work to execute under the lock; must not be null.</param>
        /// <param name="getLockTimeoutSecond">Overall time budget, in seconds, for acquiring the lock.</param>
        public void RunInLock(DBAccess db, string lock_str, Action action, int getLockTimeoutSecond = 60 * 3)
        {
            Ensure.NotNull(lock_str, nameof(lock_str));
            Ensure.NotNull(action, nameof(action));
            // A plain (non-async) lambda: the action is synchronous, so an already-completed task is returned.
            var task = RunInLockInternalAsync(db, lock_str, () =>
            {
                action();
                return Task.FromResult(true);
            }, getLockTimeoutSecond);
            task.Wait();
        }

        /// <summary>
        /// Synchronously runs <paramref name="func"/> under the lock named <paramref name="lock_str"/> and
        /// returns its result. Blocks the calling thread via <see cref="Task{TResult}.Result"/>, so failures
        /// surface wrapped in an <see cref="AggregateException"/>.
        /// </summary>
        /// <typeparam name="T">Result type of <paramref name="func"/>.</typeparam>
        /// <param name="db">Supplies the database type and connection used for the database lock.</param>
        /// <param name="lock_str">Logical lock name; must not be null.</param>
        /// <param name="func">The work to execute under the lock; must not be null.</param>
        /// <param name="getLockTimeoutSecond">Overall time budget, in seconds, for acquiring the lock.</param>
        /// <returns>The value produced by <paramref name="func"/>.</returns>
        public T RunInLock<T>(DBAccess db, string lock_str, Func<T> func, int getLockTimeoutSecond = 60 * 3)
        {
            Ensure.NotNull(lock_str, nameof(lock_str));
            Ensure.NotNull(func, nameof(func));
            var task = RunInLockInternalAsync(db, lock_str, () => Task.FromResult(func()), getLockTimeoutSecond);
            return task.Result;
        }

        /// <summary>
        /// Asynchronously runs <paramref name="func"/> under the lock named <paramref name="lock_str"/>.
        /// </summary>
        /// <param name="db">Supplies the database type and connection used for the database lock.</param>
        /// <param name="lock_str">Logical lock name; must not be null.</param>
        /// <param name="func">The asynchronous work to execute under the lock; must not be null.</param>
        /// <param name="getLockTimeoutSecond">Overall time budget, in seconds, for acquiring the lock.</param>
        public async Task RunInLockAsync(DBAccess db, string lock_str, Func<Task> func, int getLockTimeoutSecond = 60 * 3)
        {
            Ensure.NotNull(lock_str, nameof(lock_str));
            Ensure.NotNull(func, nameof(func));
            // The internal helper needs a Task<T>; a dummy bool result adapts the result-less delegate.
            await RunInLockInternalAsync(db, lock_str, async () =>
            {
                await func();
                return true;
            }, getLockTimeoutSecond);
        }

        /// <summary>
        /// Asynchronously runs <paramref name="func"/> under the lock named <paramref name="lock_str"/> and
        /// returns its result.
        /// </summary>
        /// <typeparam name="T">Result type of <paramref name="func"/>.</typeparam>
        /// <param name="db">Supplies the database type and connection used for the database lock.</param>
        /// <param name="lock_str">Logical lock name; must not be null.</param>
        /// <param name="func">The asynchronous work to execute under the lock; must not be null.</param>
        /// <param name="getLockTimeoutSecond">Overall time budget, in seconds, for acquiring the lock.</param>
        /// <returns>The value produced by <paramref name="func"/>.</returns>
        public async Task<T> RunInLockAsync<T>(DBAccess db, string lock_str, Func<Task<T>> func, int getLockTimeoutSecond = 60 * 3)
        {
            Ensure.NotNull(lock_str, nameof(lock_str));
            Ensure.NotNull(func, nameof(func));
            // NOTE(review): this is the only overload that wraps the call in db.RunInNoSessionAsync;
            // confirm whether the other three overloads should do the same.
            return await db.RunInNoSessionAsync(async () => await RunInLockInternalAsync(db, lock_str, func, getLockTimeoutSecond));
        }
    }
}
